Alzheimer's Disease Detection¶
Globally, over 55 million people are affected by Alzheimer’s, with no cure available. Early diagnosis is crucial to slow progression and improve patient quality of life.
This project aimed to develop a CNN for classifying Alzheimer’s severity into four ordinal classes: none, very mild, mild, and moderate, using more than 85,000 MRI scans from the OASIS study. The class distribution shows significant imbalance: 67,222 scans (none), 13,725 (very mild), 5,002 (mild), and 488 (moderate).
1. Importing Packages & Downloading Data¶
1.1 Importing Packages¶
!pip install tensorflow
import numpy as np
import pandas as pd
import keras
import matplotlib.pyplot as plt
import re
import os
import random
import tensorflow as tf
import plotly.express as px
from keras.models import Sequential
from PIL import Image
from keras.layers import Conv2D,Flatten,Dense,Dropout,BatchNormalization,MaxPooling2D
from sklearn.preprocessing import OneHotEncoder, label_binarize
from tensorflow.keras.layers import Conv2D, GlobalAveragePooling2D, Dense
from sklearn.model_selection import train_test_split
from sklearn.metrics import auc, average_precision_score, confusion_matrix, roc_auc_score, f1_score, precision_recall_fscore_support
from tensorflow.keras.applications import EfficientNetB0, EfficientNetV2B1
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import clone_model
from matplotlib.colors import LogNorm, LinearSegmentedColormap
from PIL import Image
from scipy.stats import skew
from tqdm import tqdm
1.2 Downloading Data¶
base_path = "Downloads/Data"
# Four categories
non_demented = []
very_mild_demented = []
mild_demented = []
moderate_demented = []
for dirname, _, filenames in os.walk(os.path.join(base_path, "Non Demented")):
    for filename in filenames:
        non_demented.append(os.path.join(dirname, filename))
for dirname, _, filenames in os.walk(os.path.join(base_path, "Very mild Dementia")):
    for filename in filenames:
        very_mild_demented.append(os.path.join(dirname, filename))
for dirname, _, filenames in os.walk(os.path.join(base_path, "Mild Dementia")):
    for filename in filenames:
        mild_demented.append(os.path.join(dirname, filename))
for dirname, _, filenames in os.walk(os.path.join(base_path, "Moderate Dementia")):
    for filename in filenames:
        moderate_demented.append(os.path.join(dirname, filename))
print("Non Demented:", len(non_demented))
print("Very Mild Demented:", len(very_mild_demented))
print("Mild Demented:", len(mild_demented))
print("Moderate Demented:", len(moderate_demented))
Non Demented: 67222 Very Mild Demented: 13725 Mild Demented: 5002 Moderate Demented: 488
1.3 Visualizing Images¶
Below we visualized example images for each class.
def display_images_with_text(file_paths, category_name, endings=['150', '151', '152']):
    plt.figure(figsize=(15, 5))
    plt.suptitle(f"Images from {category_name}", fontsize=16)
    # One subplot per layer ending, so the three panels show three different scans
    for i, ending in enumerate(endings):
        matching_files = [img for img in file_paths if img.endswith(ending + '.jpg')]
        if not matching_files:
            continue
        img = Image.open(matching_files[0])
        plt.subplot(1, 3, i + 1)
        plt.imshow(img)
        plt.axis('off')
        # Add text indicating the category
        plt.text(0, -10, f"{category_name.split()[0]} {i + 1}", color='white', fontsize=12,
                 weight='bold', ha='left', va='bottom', bbox=dict(facecolor='black', alpha=0.7))
    plt.show()
display_images_with_text(non_demented, "Non Demented")
display_images_with_text(very_mild_demented, "Very Mild Demented")
display_images_with_text(mild_demented, "Mild Demented")
display_images_with_text(moderate_demented, "Moderate Demented")
2. Exploratory Data Analysis (EDA)¶
def get_info_from_filename(filename):
    # Raw string so '\d' is a regex digit class (avoids the SyntaxWarning)
    pattern = re.compile(r'OAS1_(\d+)_MR(\d+)_mpr-(\d+)_(\d+)\.jpg')
    match = pattern.match(filename)
    patient_id = match.group(1)
    mr_id = match.group(2)
    scan_id = match.group(3)
    layer_id = match.group(4)
    return patient_id, mr_id, scan_id, layer_id
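As a quick sanity check, the parser can be exercised on a hypothetical OASIS-style filename (the ID values below are made up for illustration):

```python
import re

def get_info_from_filename(filename):
    # Raw string so '\d' is a regex digit class, not an escape sequence
    pattern = re.compile(r'OAS1_(\d+)_MR(\d+)_mpr-(\d+)_(\d+)\.jpg')
    match = pattern.match(filename)
    return match.group(1), match.group(2), match.group(3), match.group(4)

# Hypothetical filename following the OAS1_<patient>_MR<mr>_mpr-<scan>_<layer>.jpg scheme
patient_id, mr_id, scan_id, layer_id = get_info_from_filename('OAS1_0001_MR1_mpr-1_128.jpg')
print(patient_id, mr_id, scan_id, layer_id)  # 0001 1 1 128
```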
2.1.2 path, label, patient_id, mr_id, scan_id, layer_id¶
def create_ref_df(dataset_path):
    paths, labels = [], []
    patient_ids, mr_ids, scan_ids, layer_ids = [], [], [], []
    for folder in os.listdir(dataset_path):
        for file in os.listdir(os.path.join(dataset_path, folder)):
            patient_id, mr_id, scan_id, layer_id = get_info_from_filename(file)
            paths.append(os.path.join(dataset_path, folder, file))
            labels.append(folder)
            patient_ids.append(patient_id)
            mr_ids.append(mr_id)
            scan_ids.append(scan_id)
            layer_ids.append(layer_id)
    ref_df = pd.DataFrame({
        'path': paths,
        'label': labels,
        'patient_id': patient_ids,
        'mr_id': mr_ids,
        'scan_id': scan_ids,
        'layer_id': layer_ids
    })
    ref_df = ref_df.astype({
        'path': 'string',
        'label': 'string',
        'patient_id': 'int64',
        'mr_id': 'int64',
        'scan_id': 'int64',
        'layer_id': 'int64'
    })
    return ref_df
ref_df = create_ref_df('Downloads/Data')
2.2 Loading Images with Labels and Paths¶
def load_images(ref_df):
    labels = []
    images = []
    paths = []
    for idx, row in tqdm(ref_df.iterrows(), total=ref_df.shape[0]):
        # Load each scan as a single-channel (grayscale) array
        images.append(np.array(Image.open(row['path']).convert('L')))
        labels.append(row['label'])
        paths.append(row['path'])
    return images, labels, paths
images, labels, paths = load_images(ref_df)
100%|███████████████████████████████████| 86437/86437 [01:18<00:00, 1098.37it/s]
2.3 Extracting Image Statistics¶
The following statistics were extracted for each image: mean, standard deviation, width, height, and skewness.
def get_image_stats(images, labels, paths):
    means, stds, widths, heights = [], [], [], []
    skewnesses = []
    for image in tqdm(images):
        means.append(np.mean(image))
        stds.append(np.std(image))
        # NumPy image arrays are (height, width)
        heights.append(image.shape[0])
        widths.append(image.shape[1])
        # Calculate skewness of the intensity histogram counts
        image_hist = np.histogram(image.flatten())[0]
        skewnesses.append(skew(image_hist))
    image_stats = pd.DataFrame({
        'mean': means,
        'std': stds,
        'width': widths,
        'height': heights,
        'skew': skewnesses
    })
    image_stats['label'] = labels
    image_stats['path'] = paths
    return image_stats
image_stats = get_image_stats(images, labels, paths)
100%|████████████████████████████████████| 86437/86437 [01:31<00:00, 948.85it/s]
2.4 Label Distribution¶
frequencies = [len(non_demented), len(very_mild_demented), len(moderate_demented), len(mild_demented),]
class_names = ["non_demented", "very_mild_demented", "moderate_demented", "mild_demented"]
sorted_indices = np.argsort(frequencies)[::-1]
class_names = np.array(class_names)[sorted_indices]
frequencies = np.array(frequencies)[sorted_indices]
navy_blue_palette = LinearSegmentedColormap.from_list("navy_blue_palette", ["lightblue", "cornflowerblue", "royalblue", "navy"], N=4)
fig, ax = plt.subplots()
bp = ax.barh(np.arange(len(class_names)), frequencies, color=navy_blue_palette(np.linspace(0, 1, len(class_names))))
ax.set_ylim(-0.5, len(class_names)-0.5)
ax.set_yticks(np.arange(len(class_names)))
ax.set_yticklabels(class_names)
ax.axis('off')
for i, frequency in enumerate(frequencies):
    ax.text(frequency + 0.1, i, frequency, ha='left', va='center', rotation=-20)
ax.axvline(x=0, linestyle='--', color='black')
handles = [plt.Rectangle((0, 0), 1, 1, color=navy_blue_palette(i)) for i in range(4)]
ax.legend(handles, class_names, loc="upper right", title="Classes")
plt.title("Observations per Class")
plt.show()
The bar chart above shows a severe class imbalance across the four classes:
non_demented: With 67,222 observations, this class makes up the overwhelming majority of the dataset. Such a disparity increases the risk that the model learns to predict this class by default.
Very Mild Demented: With 13,725 observations, this class is considerably smaller than the non-demented class but still far larger than the remaining two.
Mild Demented: With 5,002 observations, representation drops further, which may make cases of mild dementia harder to identify correctly.
Moderate Demented: With only 488 observations, this is by far the smallest and most underrepresented class; a model may fail to learn its patterns, resulting in poor performance in identifying moderate dementia.
2.5 Mean, SD and Skewness of Images¶
We visualized the mean, standard deviation and skewness of the images with boxplots.
px.box(image_stats, y='mean', x='label')
px.box(image_stats, y='std', x='label')
px.box(image_stats, y='skew', x='label')
Looking at the box plots, an interesting trend emerges: as dementia progresses, the image statistics (mean, standard deviation, skew) concentrate into narrower bands.
In addition, skewness appears to increase with severity. Though not definitive, these findings suggest that such statistical features could help enable and accelerate the training of a deep learning model.
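The per-class trend can be quantified with a simple groupby. Below is a minimal sketch on a toy stand-in for image_stats (the numbers are illustrative, not the real OASIS values); narrower bands show up as a smaller within-class spread:

```python
import pandas as pd

# Toy stand-in for the real image_stats DataFrame (values are illustrative only)
toy_stats = pd.DataFrame({
    'label': ['Non Demented', 'Non Demented', 'Moderate Dementia', 'Moderate Dementia'],
    'mean':  [52.1, 58.4, 61.0, 61.8],
    'skew':  [1.2, 1.4, 2.0, 2.1],
})

# Within-class average and spread of each statistic; a narrower band -> smaller std
per_class = toy_stats.groupby('label').agg(
    mean_avg=('mean', 'mean'),
    mean_spread=('mean', 'std'),
    skew_avg=('skew', 'mean'),
)
print(per_class)
```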
3. Data Pre-Processing¶
3.1 Splitting Off a Test Set¶
random.seed(42)
moderate_demented_train, moderate_demented_test = train_test_split(
moderate_demented, test_size=0.2, random_state=42
)
mild_demented_train, mild_demented_test = train_test_split(
mild_demented, test_size=0.2, random_state=42
)
very_mild_demented_train, very_mild_demented_test = train_test_split(
very_mild_demented, test_size=0.2, random_state=42
)
non_demented_train, non_demented_test = train_test_split(
non_demented, test_size=0.2, random_state=42
)
The dataset is heavily imbalanced: the minority class (moderate dementia, 488 scans) amounts to roughly 0.7% of the majority class (non-demented, 67,222 scans). This disparity reflects the real-world prevalence of Alzheimer's disease at its various stages of severity, but it poses serious challenges when training neural networks.
With such an imbalance, traditional performance metrics can be misleading: a model can achieve seemingly high accuracy simply by favoring the overwhelming majority class, while failing to generalize to minority classes such as moderate dementia. This undermines the very purpose of the classifier, which is to differentiate accurately among the levels of cognitive impairment.
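An alternative (or complement) to resampling is class weighting. The sketch below derives inverse-frequency weights from the class counts above; a dict of this shape can be passed to Keras's fit via its class_weight argument. The normalization (dividing by the number of classes) keeps the weights averaging to 1:

```python
# Class counts from the OASIS dataset: none, very mild, mild, moderate
counts = {0: 67222, 1: 13725, 2: 5002, 3: 488}
total = sum(counts.values())
n_classes = len(counts)

# Inverse-frequency weights: rare classes get proportionally larger weights
class_weight = {label: total / (n_classes * count) for label, count in counts.items()}
for label, w in sorted(class_weight.items()):
    print(f"class {label}: weight {w:.2f}")
```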
3.2.1 Under-/Oversampling Train Data¶
print(len(non_demented_train))
print(len(very_mild_demented_train))
print(len(mild_demented_train))
print(len(moderate_demented_train))
53777 10980 4001 390
target_samples = 5000
moderate_demented_samp = random.choices(moderate_demented_train, k=target_samples)
mild_demented_samp = random.choices(mild_demented_train, k=target_samples)
very_mild_demented_samp = random.sample(very_mild_demented_train, k=target_samples)
non_demented_samp = random.sample(non_demented_train, k=target_samples)
print(len(non_demented_samp))
print(len(very_mild_demented_samp))
print(len(mild_demented_samp))
print(len(moderate_demented_samp))
5000 5000 5000 5000
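The distinction between the two sampling functions above matters: random.choices draws with replacement (so a 390-image class can be inflated to 5,000, with repeats), while random.sample draws without replacement and raises an error if k exceeds the population. A toy illustration:

```python
import random

random.seed(0)
population = ['a', 'b', 'c']

# With replacement: k may exceed the population size; duplicates are expected
oversampled = random.choices(population, k=10)
assert len(oversampled) == 10

# Without replacement: every drawn item is distinct
undersampled = random.sample(population, k=2)
assert len(set(undersampled)) == 2

# random.sample raises ValueError when k > len(population)
try:
    random.sample(population, k=10)
except ValueError:
    print("sample() cannot draw more items than the population holds")
```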
3.2.2 Undersampling Test Data¶
print(len(non_demented_test))
print(len(very_mild_demented_test))
print(len(mild_demented_test))
print(len(moderate_demented_test))
13445 2745 1001 98
target_samples = 640
mild_demented_test = random.sample(mild_demented_test, k=target_samples)
very_mild_demented_test = random.sample(very_mild_demented_test, k=target_samples)
non_demented_test = random.sample(non_demented_test, k=target_samples)
3.3 Merging, Shaping, Splitting¶
3.3.1 One-Hot Encoder¶
We needed a function to one-hot encode the class labels.
# One-hot encoder for class labels
encoder = OneHotEncoder()
encoder.fit([[0],[1],[2],[3]])
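For four integer labels, the fitted encoder is equivalent to indexing into an identity matrix; here is a dependency-free sketch of that equivalence with NumPy:

```python
import numpy as np

def one_hot(label, n_classes=4):
    # Row `label` of the identity matrix is the one-hot vector for that class
    return np.eye(n_classes)[label]

# Class 2 ("mild") -> [0. 0. 1. 0.]
print(one_hot(2))
```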
3.3.2 Train Data¶
Each class's training images were resized to the target shape of 128x128x3; we chose 128 pixels to reduce computation time. All resized images were collected into a single 'data' object, and the class labels were one-hot encoded and stored in the 'result' object.
After this, we split this data into a training set (80%) and a validation set (20%).
data = []
result = []
# Labels 0-3 follow the order: none, very mild, mild, moderate
for label, sample_paths in enumerate([non_demented_samp, very_mild_demented_samp,
                                      mild_demented_samp, moderate_demented_samp]):
    for path in sample_paths:
        img = Image.open(path)
        img = img.resize((128, 128))
        img = np.array(img)
        if img.shape == (128, 128, 3):
            data.append(img)
            result.append(encoder.transform([[label]]).toarray())
data = np.array(data)
data.shape
(20000, 128, 128, 3)
result = np.array(result)
result = result.reshape((data.shape[0],4))
result.shape
(20000, 4)
x_train,x_val,y_train,y_val = train_test_split(data,result, test_size=0.20, shuffle=True, random_state=42)
3.3.3 Test Data¶
For the test data as well, we transformed the images from each class into the target shape (128x128x3) and merged the categories into the 'data_test' object. Again, the class labels were one-hot encoded and stored in the 'result_test' object.
data_test = []
result_test = []
# Labels 0-3 follow the order: none, very mild, mild, moderate
for label, sample_paths in enumerate([non_demented_test, very_mild_demented_test,
                                      mild_demented_test, moderate_demented_test]):
    for path in sample_paths:
        img = Image.open(path)
        img = img.resize((128, 128))
        img = np.array(img)
        if img.shape == (128, 128, 3):
            data_test.append(img)
            result_test.append(encoder.transform([[label]]).toarray())
# Transform data to numpy array
data_test = np.array(data_test)
data_test.shape
(2018, 128, 128, 3)
# Transform labels to numpy array
result_test = np.array(result_test)
result_test = result_test.reshape((data_test.shape[0],4))
result_test.shape
(2018, 4)
We changed the names of 'data_test' to 'x_test' and 'result_test' to 'y_test' for the sake of consistency.
# Change names to x_test and y_test
x_test = data_test
y_test = result_test
The data pre-processing thus resulted in the following data objects:
- x_train: train data
- y_train: train labels
- x_val: validation data (used in training/modeling)
- y_val: validation labels (used in training/modeling)
- x_test: test data (evaluation after modeling)
- y_test: test labels (evaluation after modeling)
4. Modeling¶
4.1 Model Architecture: EfficientNetB0¶
# Convert labels to integers
y_train_int = np.argmax(y_train, axis=1)
y_val_int = np.argmax(y_val, axis=1)
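np.argmax along axis=1 inverts the one-hot encoding; a quick check on a toy batch:

```python
import numpy as np

# Toy batch of one-hot rows encoding classes 3, 0, 2
y_one_hot = np.array([
    [0, 0, 0, 1],
    [1, 0, 0, 0],
    [0, 0, 1, 0],
])

# The column index of the 1 in each row is the integer class label
y_int = np.argmax(y_one_hot, axis=1)
print(y_int)  # [3 0 2]
```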
pretrained_base_EfficientNetB0 = EfficientNetB0(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
pretrained_base_EfficientNetB0.trainable = True
model_EfficientNetB0 = models.Sequential([
    pretrained_base_EfficientNetB0,
    layers.BatchNormalization(),
    layers.GlobalAveragePooling2D(),
    layers.Dense(512, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(4, activation='softmax')
])
model_EfficientNetB0.summary()
Model: "sequential"
Layer (type)                                       Output Shape          Param #
efficientnetb0 (Functional)                        (None, 4, 4, 1280)    4,049,571
batch_normalization (BatchNormalization)           (None, 4, 4, 1280)    5,120
global_average_pooling2d (GlobalAveragePooling2D)  (None, 1280)          0
dense (Dense)                                      (None, 512)           655,872
dropout (Dropout)                                  (None, 512)           0
dense_1 (Dense)                                    (None, 4)             2,052
Total params: 4,712,615 (17.98 MB)
Trainable params: 4,668,032 (17.81 MB)
Non-trainable params: 44,583 (174.16 KB)
4.2 Hyperparameter Tuning¶
def create_model(learning_rate=0.001):
    # Build a fresh pretrained base per call, so grid-search runs don't share weights
    base = EfficientNetB0(weights='imagenet', include_top=False, input_shape=(128, 128, 3))
    # Train the pretrained base
    base.trainable = True
    model = Sequential([
        base,
        BatchNormalization(),
        GlobalAveragePooling2D(),
        Dense(256, activation='relu'),
        Dropout(0.5),
        Dense(4, activation='softmax')
    ])
    optimizer = Adam(learning_rate=learning_rate)
    # Integer labels -> sparse categorical crossentropy
    model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model
# Specify values for grid search
learning_rate_values = [0.001, 0.01, 0.1]
batch_size_values = [16, 32, 64]

# Tune learning rate and batch size
best_accuracy = 0
best_params = {}
for learning_rate in learning_rate_values:
    for batch_size in batch_size_values:
        # Create model
        model = create_model(learning_rate=learning_rate)
        # Train model
        history = model.fit(x_train, y_train_int, epochs=5, batch_size=batch_size,
                            validation_data=(x_val, y_val_int))
        # Evaluate model on validation set
        accuracy = model.evaluate(x_val, y_val_int)[1]
        # Print or store the results
        print(f'Learning Rate: {learning_rate}, Batch Size: {batch_size}, Accuracy: {accuracy}')
        # Update best parameters if needed
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_params.update({'learning_rate': learning_rate, 'batch_size': batch_size})

# Print the best hyperparameters
print(f'Best Hyperparameters: {best_params}')
Training logs abridged; final validation accuracy per configuration:
lr=0.001: batch 16: 0.9865 | batch 32: 0.9805 | batch 64: 0.9942
lr=0.01:  batch 16: 0.8050 | batch 32: 0.9323 | batch 64: 0.9815
lr=0.1:   batch 16: 0.2510 | batch 32: 0.2442 | batch 64: 0.2510
Best Hyperparameters: {'learning_rate': 0.001, 'batch_size': 64}
The grid search yielded the following best values:
- Learning rate: 0.001
- Batch size: 64
4.5 Training: Regular Classification¶
# Store model architecture in new object, to prevent confusion (e.g. with compiling, training and calculating test accuracy later)
# By cloning the model
model_EfficientNetB0_rc = clone_model(model_EfficientNetB0)
# Adam optimizer with learning rate 0.01
optimizer = Adam(learning_rate=0.01)
# Compile with ('regular') categorical_crossentropy
model_EfficientNetB0_rc.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
# Define early stopping callback
early_stopping = EarlyStopping(monitor='val_loss', patience=5, min_delta=0.0001, restore_best_weights=True)
# Train
history_EfficientNetB0_rc = model_EfficientNetB0_rc.fit(x_train, y_train, epochs=20, batch_size=64,
verbose=1, validation_data=(x_val, y_val), callbacks=[early_stopping])
Epoch 1/20 250/250 ━━━━━━━━━━━━━━━━━━━━ 330s 1s/step - accuracy: 0.3644 - loss: 2.7837 - val_accuracy: 0.2770 - val_loss: 2.1154 Epoch 2/20 250/250 ━━━━━━━━━━━━━━━━━━━━ 319s 1s/step - accuracy: 0.7169 - loss: 0.6119 - val_accuracy: 0.8562 - val_loss: 0.3913 Epoch 3/20 250/250 ━━━━━━━━━━━━━━━━━━━━ 318s 1s/step - accuracy: 0.8474 - loss: 0.3797 - val_accuracy: 0.8303 - val_loss: 0.5426 Epoch 4/20 250/250 ━━━━━━━━━━━━━━━━━━━━ 317s 1s/step - accuracy: 0.8991 - loss: 0.2708 - val_accuracy: 0.7943 - val_loss: 0.7014 Epoch 5/20 250/250 ━━━━━━━━━━━━━━━━━━━━ 319s 1s/step - accuracy: 0.9349 - loss: 0.1783 - val_accuracy: 0.7930 - val_loss: 0.7491 Epoch 6/20 250/250 ━━━━━━━━━━━━━━━━━━━━ 320s 1s/step - accuracy: 0.9241 - loss: 0.2264 - val_accuracy: 0.9128 - val_loss: 0.3463 Epoch 7/20 250/250 ━━━━━━━━━━━━━━━━━━━━ 317s 1s/step - accuracy: 0.9509 - loss: 0.1433 - val_accuracy: 0.9180 - val_loss: 0.3520 Epoch 8/20 250/250 ━━━━━━━━━━━━━━━━━━━━ 324s 1s/step - accuracy: 0.9641 - loss: 0.1008 - val_accuracy: 0.9540 - val_loss: 0.1698 Epoch 9/20 250/250 ━━━━━━━━━━━━━━━━━━━━ 317s 1s/step - accuracy: 0.9706 - loss: 0.0903 - val_accuracy: 0.8680 - val_loss: 0.5946 Epoch 10/20 250/250 ━━━━━━━━━━━━━━━━━━━━ 316s 1s/step - accuracy: 0.9663 - loss: 0.1103 - val_accuracy: 0.8755 - val_loss: 0.4972 Epoch 11/20 250/250 ━━━━━━━━━━━━━━━━━━━━ 311s 1s/step - accuracy: 0.9681 - loss: 0.0987 - val_accuracy: 0.8570 - val_loss: 0.7702 Epoch 12/20 250/250 ━━━━━━━━━━━━━━━━━━━━ 307s 1s/step - accuracy: 0.9720 - loss: 0.0951 - val_accuracy: 0.6192 - val_loss: 3.7637 Epoch 13/20 250/250 ━━━━━━━━━━━━━━━━━━━━ 314s 1s/step - accuracy: 0.9637 - loss: 0.1215 - val_accuracy: 0.9498 - val_loss: 0.1810
# Save the entire model
model_EfficientNetB0_rc.save('saved_model/final_model.h5')
WARNING:absl:You are saving your model as an HDF5 file via `model.save()` or `keras.saving.save_model(model)`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')` or `keras.saving.save_model(model, 'my_model.keras')`.
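As the warning notes, HDF5 is a legacy format. A hedged sketch of the recommended alternative, using a tiny stand-in model and an illustrative path (in the notebook this would be `model_EfficientNetB0_rc`):

```python
from tensorflow import keras

# Tiny stand-in model; in the notebook this would be the trained network.
model = keras.Sequential([keras.layers.Input(shape=(4,)),
                          keras.layers.Dense(2)])

# Save in the native Keras format the warning recommends, not legacy HDF5.
model.save('final_model.keras')

# Round trip: reload and confirm the weights survive.
restored = keras.models.load_model('final_model.keras')
```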
Results¶
To examine the performance of our model and compare ordinal with regular classification, we calculated several performance metrics on the validation and test data.
5. Results on Train and Validation Data¶
To have a first glance at the model's performance, we plotted the loss and accuracy of the validation data and inspected the convergence, for both the ordinal and regular classification.
5.1 Loss & Accuracy Regular Classification¶
# Loss & accuracy of regular model
history_EfficientNetB0_rc_frame = pd.DataFrame(history_EfficientNetB0_rc.history)
history_EfficientNetB0_rc_frame.loc[:, ['loss', 'val_loss']].plot()
history_EfficientNetB0_rc_frame.loc[:, ['accuracy', 'val_accuracy']].plot()
<Axes: >
The plot shows the training loss decreasing steadily while the validation loss fluctuates, especially in later epochs, which suggests some overfitting. Validation accuracy likewise swings from epoch to epoch, with a sharp drop at epoch 12 before recovering. Since early stopping with best-weight restoration was already in place, stronger regularization (e.g. weight decay or heavier dropout) or a learning-rate schedule could help stabilize validation performance.
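One common way to tame validation-loss spikes like those above is to shrink the learning rate when validation loss stalls. A hedged sketch using Keras's `ReduceLROnPlateau` alongside the early stopping already defined; the factor and patience values are illustrative, not tuned:

```python
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

# Halve the learning rate whenever val_loss stalls for 2 epochs
# (values are illustrative assumptions, not tuned for this dataset).
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2,
                              min_lr=1e-5, verbose=1)
early_stopping = EarlyStopping(monitor='val_loss', patience=5,
                               min_delta=0.0001, restore_best_weights=True)
# model_EfficientNetB0_rc.fit(..., callbacks=[early_stopping, reduce_lr])
```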
6. Results on Test Data: Regular Classification¶
6.1 Accuracy & Loss¶
# Evaluate on test set
testeval = model_EfficientNetB0_rc.evaluate(x_test, y_test, verbose=2)
# Print performance metrics (loss and accuracy)
print("Test Loss:", testeval[0])
print("Test Accuracy:", testeval[1])
64/64 - 7s - 111ms/step - accuracy: 0.9386 - loss: 0.2321 Test Loss: 0.23210683465003967 Test Accuracy: 0.9385530352592468
6.2 Confusion Matrix (Log Scale)¶
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
import matplotlib.pyplot as plt
# Convert true labels to integer format
y_test_int = np.argmax(y_test, axis=1) # If y_test is one-hot encoded
# Predict probabilities and convert to integer labels
y_pred_probs = model_EfficientNetB0_rc.predict(x_test)
y_pred_rc_int = np.argmax(y_pred_probs, axis=1)
# Create confusion matrix
cm = confusion_matrix(y_test_int, y_pred_rc_int)
# Plot confusion matrix with numbers
plt.figure(figsize=(8, 6))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=np.unique(y_test_int))
# Display the confusion matrix
disp.plot(cmap=plt.cm.Blues, values_format='d', ax=plt.gca())
plt.title("Confusion Matrix with Counts and Labels")
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.show()
# Classification Report
print("Classification Report:")
print(classification_report(y_test_int, y_pred_rc_int))
64/64 ━━━━━━━━━━━━━━━━━━━━ 5s 78ms/step
Classification Report:
precision recall f1-score support
0 0.92 0.95 0.93 640
1 0.95 0.89 0.92 640
2 0.98 0.97 0.97 640
3 0.79 1.00 0.88 98
accuracy 0.94 2018
macro avg 0.91 0.95 0.93 2018
weighted avg 0.94 0.94 0.94 2018
6.3 F1 Score, Precision, Recall¶
from tensorflow.keras.utils import to_categorical
from sklearn.metrics import precision_recall_fscore_support

y_pred_rc = to_categorical(y_pred_rc_int, num_classes=4)
precision, recall, f1, _ = precision_recall_fscore_support(y_test_int, y_pred_rc_int, average=None)
for i, (p, r, f) in enumerate(zip(precision, recall, f1)):
print(f'Class {i}: Precision = {p:.4f}, Recall = {r:.4f}, F1 Score = {f:.4f}')
micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support(y_test, y_pred_rc, average='micro')
print("Micro-average Precision:", round(micro_precision, 4))
print("Micro-average Recall:", round(micro_recall, 4))
print("Micro-average F1 Score:", round(micro_f1, 4))
Class 0: Precision = 0.9186, Recall = 0.9516, F1 Score = 0.9348 Class 1: Precision = 0.9498, Recall = 0.8875, F1 Score = 0.9176 Class 2: Precision = 0.9779, Recall = 0.9672, F1 Score = 0.9725 Class 3: Precision = 0.7903, Recall = 1.0000, F1 Score = 0.8829 Micro-average Precision: 0.9386 Micro-average Recall: 0.9386 Micro-average F1 Score: 0.9386
6.4 ROC-AUC Score¶
from sklearn.metrics import roc_curve, auc
from itertools import cycle
# Get the probabilities and true labels
y_probs_rc = model_EfficientNetB0_rc.predict(x_test)
n_classes = y_probs_rc.shape[1]
# Compute ROC curve and AUC for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
fpr[i], tpr[i], _ = roc_curve(y_test_int == i, y_probs_rc[:, i])
roc_auc[i] = auc(fpr[i], tpr[i])
# Compute micro-average ROC curve and ROC AUC
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_probs_rc.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
# Plot all ROC curves
plt.figure(figsize=(10, 8))
colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'green'])
for i, color in zip(range(n_classes), colors):
plt.plot(fpr[i], tpr[i], color=color, lw=2,
label=f'Class {i} (AUC = {roc_auc[i]:.2f})')
# Plot micro-average ROC curve
plt.plot(fpr["micro"], tpr["micro"],
label=f'Micro-average (AUC = {roc_auc["micro"]:.2f})',
color='deeppink', linestyle=':', linewidth=4)
# Plot diagonal
plt.plot([0, 1], [0, 1], 'k--', lw=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=12)
plt.ylabel('True Positive Rate', fontsize=12)
plt.title('Multi-class ROC Curves', fontsize=16)
plt.legend(loc="lower right", fontsize=10)
plt.grid(alpha=0.3)
plt.show()
64/64 ━━━━━━━━━━━━━━━━━━━━ 5s 81ms/step
7. Basic Model¶
We formulated a "basic" model in which we specified convolution, batch normalization, pooling, dropout, flatten and dense layers ourselves. As with our final model, this model was trained for ordinal and regular classification. Accuracy, loss and Scott's pi were calculated on the test data.
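The Scott's pi calculation mentioned above is not shown in this section, so here is a minimal sketch from first principles: observed agreement Po corrected by chance agreement Pe, where Pe uses the pooled (averaged) marginal distribution of both label sequences.

```python
import numpy as np

def scotts_pi(y_true, y_pred):
    """Scott's pi: chance-corrected agreement between two label sequences.
    Pe is computed from the pooled marginals of both 'raters'."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    po = np.mean(y_true == y_pred)                       # observed agreement
    classes = np.union1d(y_true, y_pred)
    pooled = [(np.mean(y_true == c) + np.mean(y_pred == c)) / 2
              for c in classes]
    pe = float(np.sum(np.square(pooled)))                # chance agreement
    return (po - pe) / (1 - pe)

# e.g. scotts_pi(y_test_int, y_pred_rc_int) for the regular classifier
```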
7.1 Model Architecture¶
# Model architecture
model = Sequential()
model.add(Conv2D(32, kernel_size=(2, 2), activation='relu', input_shape=(128, 128, 3), padding='same'))
model.add(Conv2D(32, kernel_size=(2, 2), activation='relu', padding='same'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Conv2D(64, kernel_size=(2, 2), activation='relu', padding='same'))
model.add(Conv2D(64, kernel_size=(2, 2), activation='relu', padding='same'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(4, activation='softmax'))
# Summary of the model
model.summary()
/opt/anaconda3/lib/python3.12/site-packages/keras/src/layers/convolutional/base_conv.py:107: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
Model: "sequential_10"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ conv2d (Conv2D) │ (None, 128, 128, 32) │ 416 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_1 (Conv2D) │ (None, 128, 128, 32) │ 4,128 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_10 │ (None, 128, 128, 32) │ 128 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d (MaxPooling2D) │ (None, 64, 64, 32) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_10 (Dropout) │ (None, 64, 64, 32) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_2 (Conv2D) │ (None, 64, 64, 64) │ 8,256 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_3 (Conv2D) │ (None, 64, 64, 64) │ 16,448 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_11 │ (None, 64, 64, 64) │ 256 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_1 (MaxPooling2D) │ (None, 32, 32, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_11 (Dropout) │ (None, 32, 32, 64) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ flatten (Flatten) │ (None, 65536) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_20 (Dense) │ (None, 512) │ 33,554,944 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dropout_12 (Dropout) │ (None, 512) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_21 (Dense) │ (None, 4) │ 2,052 │ 
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 33,586,628 (128.12 MB)
Trainable params: 33,586,436 (128.12 MB)
Non-trainable params: 192 (768.00 B)
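The Flatten → Dense(512) layer dominates the parameter count; the totals in the summary above can be checked by hand with the standard per-layer formulas:

```python
# Parameter-count check for the summary above:
#   Conv2D  = (kh*kw*in_channels + 1) * out_channels
#   BatchNormalization = 4 * channels (gamma, beta, moving mean/variance)
#   Dense   = (in_features + 1) * out_features
conv = [(2*2*3 + 1) * 32,     # 416
        (2*2*32 + 1) * 32,    # 4,128
        (2*2*32 + 1) * 64,    # 8,256
        (2*2*64 + 1) * 64]    # 16,448
bn = [4 * 32, 4 * 64]         # 128 and 256 (half trainable, half not)
dense = [(32*32*64 + 1) * 512,  # 33,554,944 (flattened 32x32x64 feature map)
         (512 + 1) * 4]         # 2,052
total = sum(conv) + sum(bn) + sum(dense)
print(total)  # 33586628, matching the summary's total
```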
7.2 Training for Regular Classification¶
# Clone the architecture into a new object, to avoid confusion later
# (e.g. when compiling, training and computing test accuracy)
model_rc = clone_model(model)
# Compile with ('regular') categorical_crossentropy
model_rc.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# Train the model
history_rc = model_rc.fit(x_train, y_train, epochs=10, batch_size=16,
verbose=1, validation_data=(x_val, y_val))
Epoch 1/10 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 202s 202ms/step - accuracy: 0.6062 - loss: 4.3680 - val_accuracy: 0.8570 - val_loss: 0.3280 Epoch 2/10 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 214s 214ms/step - accuracy: 0.8216 - loss: 0.4444 - val_accuracy: 0.9072 - val_loss: 0.2152 Epoch 3/10 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 226s 226ms/step - accuracy: 0.8607 - loss: 0.3567 - val_accuracy: 0.9170 - val_loss: 0.2206 Epoch 4/10 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 232s 232ms/step - accuracy: 0.8926 - loss: 0.2837 - val_accuracy: 0.8880 - val_loss: 0.3123 Epoch 5/10 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 235s 235ms/step - accuracy: 0.8950 - loss: 0.3049 - val_accuracy: 0.9350 - val_loss: 0.1604 Epoch 6/10 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 250s 250ms/step - accuracy: 0.9131 - loss: 0.2148 - val_accuracy: 0.9575 - val_loss: 0.1020 Epoch 7/10 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 267s 267ms/step - accuracy: 0.9274 - loss: 0.2232 - val_accuracy: 0.9730 - val_loss: 0.0692 Epoch 8/10 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 267s 267ms/step - accuracy: 0.9386 - loss: 0.1667 - val_accuracy: 0.9735 - val_loss: 0.0593 Epoch 9/10 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 254s 254ms/step - accuracy: 0.9465 - loss: 0.1486 - val_accuracy: 0.9715 - val_loss: 0.0649 Epoch 10/10 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 244s 244ms/step - accuracy: 0.9391 - loss: 0.1776 - val_accuracy: 0.9760 - val_loss: 0.0730
7.3 Results on Test Set¶
7.3.1 Accuracy & Loss¶
# Evaluate on test set
testeval = model_rc.evaluate(x_test, y_test, verbose=2)
# Printing performance metrics (loss and accuracy)
print("Test Loss:", testeval[0])
print("Test Accuracy:", testeval[1])
64/64 - 3s - 49ms/step - accuracy: 0.9658 - loss: 0.0930 Test Loss: 0.09298207610845566 Test Accuracy: 0.9658077359199524
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
# Note: this confusion matrix is computed on the validation set, not the test set
y_true = np.argmax(y_val, axis=1)  # y_val is one-hot encoded
# Predict probabilities and convert to class labels
y_pred_probs = model_rc.predict(x_val)
y_pred = np.argmax(y_pred_probs, axis=1)
# Confusion Matrix
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(8, 6))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
tick_marks = np.arange(len(np.unique(y_true)))
plt.xticks(tick_marks, np.unique(y_true), rotation=45)
plt.yticks(tick_marks, np.unique(y_true))
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
# Display numbers in the confusion matrix
for i in range(len(cm)):
for j in range(len(cm[i])):
plt.text(j, i, cm[i, j], horizontalalignment='center',
color='white' if cm[i, j] > cm.max() / 2 else 'black')
plt.show()
125/125 ━━━━━━━━━━━━━━━━━━━━ 7s 52ms/step
from sklearn.metrics import classification_report
print("Classification Report:")
print(classification_report(y_true, y_pred))
Classification Report:
precision recall f1-score support
0 0.99 0.91 0.95 995
1 0.93 0.99 0.96 1024
2 0.99 1.00 0.99 1004
3 1.00 1.00 1.00 977
accuracy 0.98 4000
macro avg 0.98 0.98 0.98 4000
weighted avg 0.98 0.98 0.98 4000
# Ensure x_test and y_test are consistent
print("x_test shape:", x_test.shape)
print("y_test shape:", y_test.shape)
# Predict probabilities and ensure shapes are consistent
y_pred_probs = model_EfficientNetB0_rc.predict(x_test)
print("y_pred_probs shape:", y_pred_probs.shape)
# Convert y_test to integer labels if one-hot encoded
if len(y_test.shape) > 1:
y_true = np.argmax(y_test, axis=1)
else:
y_true = y_test
# Convert predicted probabilities to predicted labels
y_pred = np.argmax(y_pred_probs, axis=1)
# Verify matching lengths
print("y_true shape:", y_true.shape)
print("y_pred shape:", y_pred.shape)
assert len(y_true) == len(y_pred), "Mismatch between y_true and y_pred lengths!"
x_test shape: (2018, 128, 128, 3) y_test shape: (2018, 4) 64/64 ━━━━━━━━━━━━━━━━━━━━ 5s 84ms/step y_pred_probs shape: (2018, 4) y_true shape: (2018,) y_pred shape: (2018,)
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score, roc_curve, auc
from sklearn.preprocessing import label_binarize
from itertools import cycle
# Assuming y_true and y_pred are already defined
# For example:
# y_true = np.argmax(y_test, axis=1) # Convert one-hot to integer labels
# y_pred_probs = model.predict(x_test)
# y_pred = np.argmax(y_pred_probs, axis=1)
# Calculate F1 Score, Precision, and Recall
f1 = f1_score(y_true, y_pred, average='weighted')
precision = precision_score(y_true, y_pred, average='weighted')
recall = recall_score(y_true, y_pred, average='weighted')
print(f"F1 Score: {f1:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
# Multi-class ROC-AUC Graph
n_classes = len(np.unique(y_true)) # Number of classes
y_true_binarized = label_binarize(y_true, classes=np.arange(n_classes)) # Convert to binary format
# Compute ROC curve and AUC for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
fpr[i], tpr[i], _ = roc_curve(y_true_binarized[:, i], y_pred_probs[:, i])
roc_auc[i] = auc(fpr[i], tpr[i])
# Compute micro-average ROC curve and ROC AUC
fpr["micro"], tpr["micro"], _ = roc_curve(y_true_binarized.ravel(), y_pred_probs.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
# Plot ROC-AUC curves
plt.figure(figsize=(10, 8))
colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'green'])
for i, color in zip(range(n_classes), colors):
plt.plot(fpr[i], tpr[i], color=color, lw=2,
label=f'Class {i} (AUC = {roc_auc[i]:.2f})')
# Plot micro-average ROC curve
plt.plot(fpr["micro"], tpr["micro"],
label=f'Micro-average (AUC = {roc_auc["micro"]:.2f})',
color='deeppink', linestyle=':', linewidth=4)
# Plot diagonal
plt.plot([0, 1], [0, 1], 'k--', lw=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=12)
plt.ylabel('True Positive Rate', fontsize=12)
plt.title('Multi-class ROC Curves', fontsize=16)
plt.legend(loc="lower right", fontsize=10)
plt.grid(alpha=0.3)
plt.show()
F1 Score: 0.9388 Precision: 0.9411 Recall: 0.9386
import matplotlib.pyplot as plt
# Data from the training process
epochs = range(1, 11) # 10 epochs
train_accuracy = [0.6062, 0.8216, 0.8607, 0.8926, 0.8950, 0.9131, 0.9274, 0.9386, 0.9465, 0.9391]
val_accuracy = [0.8570, 0.9072, 0.9170, 0.8880, 0.9350, 0.9575, 0.9730, 0.9735, 0.9715, 0.9760]
# Find highest and lowest values
train_max = max(train_accuracy)
train_min = min(train_accuracy)
val_max = max(val_accuracy)
val_min = min(val_accuracy)
# Plot the training and validation accuracy
plt.figure(figsize=(12, 8))
plt.plot(
epochs,
train_accuracy,
label='Training Accuracy',
color='blue',
linewidth=2,
marker='o',
markersize=6,
markerfacecolor='yellow',
markeredgecolor='black'
)
plt.plot(
epochs,
val_accuracy,
label='Validation Accuracy',
color='green',
linewidth=2,
marker='s',
markersize=6,
markerfacecolor='lime',
markeredgecolor='black',
linestyle='--'
)
# Mark all data points
for i, (train, val) in enumerate(zip(train_accuracy, val_accuracy), start=1):
plt.text(i, train, f'{train:.4f}', fontsize=10, ha='right', va='bottom', color='blue')
plt.text(i, val, f'{val:.4f}', fontsize=10, ha='left', va='bottom', color='green')
# Highlight highest and lowest points
plt.scatter([train_accuracy.index(train_max) + 1], [train_max], color='red', s=100, label='Highest Training Accuracy')
plt.scatter([train_accuracy.index(train_min) + 1], [train_min], color='orange', s=100, label='Lowest Training Accuracy')
plt.scatter([val_accuracy.index(val_max) + 1], [val_max], color='red', s=100, label='Highest Validation Accuracy', marker='D')
plt.scatter([val_accuracy.index(val_min) + 1], [val_min], color='orange', s=100, label='Lowest Validation Accuracy', marker='D')
# Add labels, title, and legend
plt.title('Training vs Validation Accuracy Over Epochs', fontsize=18, color='darkblue')
plt.xlabel('Epochs', fontsize=14, color='darkred')
plt.ylabel('Accuracy', fontsize=14, color='darkred')
plt.xticks(epochs, fontsize=12, color='brown')
plt.yticks(fontsize=12, color='brown')
plt.legend(fontsize=12, loc='lower right')
# Add grid
plt.grid(color='gray', linestyle='--', linewidth=0.5, alpha=0.5)
# Display the plot
plt.tight_layout()
plt.show()
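The hand-typed accuracy lists above duplicate values that `history_rc` already holds; a sketch that reads directly from the Keras History dict avoids transcription errors (the simplified styling here is illustrative):

```python
import matplotlib.pyplot as plt

def plot_accuracy(history):
    """Plot training vs validation accuracy straight from a Keras
    History dict instead of hand-copied lists."""
    epochs = range(1, len(history['accuracy']) + 1)
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.plot(epochs, history['accuracy'], 'o-', label='Training Accuracy')
    ax.plot(epochs, history['val_accuracy'], 's--',
            label='Validation Accuracy')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Accuracy')
    ax.set_title('Training vs Validation Accuracy Over Epochs')
    ax.legend(loc='lower right')
    ax.grid(linestyle='--', alpha=0.5)
    return fig

# In the notebook: plot_accuracy(history_rc.history); plt.show()
```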
8. Comparison between both models¶
We compared two models on this classification task: the pretrained EfficientNetB0 and our custom CNN.
8.1 Visualization of Classification Metrics¶
import matplotlib.pyplot as plt
import numpy as np
# Metrics from Model 1 (EfficientNetB0)
metrics_model1 = {
"Precision": 0.94,
"Recall": 0.94,
"F1-Score": 0.94,
"Accuracy": 0.94
}
# Metrics from Model 2 (Custom CNN)
metrics_model2 = {
"Precision": 0.98,
"Recall": 0.98,
"F1-Score": 0.98,
"Accuracy": 0.98
}
# Visualization
labels = list(metrics_model1.keys())
model1_values = list(metrics_model1.values())
model2_values = list(metrics_model2.values())
x = np.arange(len(labels)) # Label locations
width = 0.35 # Bar width
fig, ax = plt.subplots(figsize=(8, 6))
rects1 = ax.bar(x - width/2, model1_values, width, label='Model 1 (EfficientNetB0)')
rects2 = ax.bar(x + width/2, model2_values, width, label='Model 2 (Custom CNN)')
# Add text for labels, title, and custom x-axis tick labels
ax.set_xlabel('Metrics', fontsize=12)
ax.set_ylabel('Scores', fontsize=12)
ax.set_title('Comparison of Classification Metrics', fontsize=16)
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()
# Annotate bars
for rects in [rects1, rects2]:
for rect in rects:
height = rect.get_height()
ax.annotate(f'{height:.2f}',
xy=(rect.get_x() + rect.get_width() / 2, height),
xytext=(0, 3), # 3 points vertical offset
textcoords="offset points",
ha='center', va='bottom')
plt.tight_layout()
plt.show()
8.2 Confusion Matrix Summary Comparison¶
# Correct predictions from confusion matrices
correct_model1 = [609, 568, 619, 98] # Diagonal of Model 1 confusion matrix
correct_model2 = [906, 1017, 1004, 977] # Diagonal of Model 2 confusion matrix
labels = ["Class 0", "Class 1", "Class 2", "Class 3"]
x = np.arange(len(labels)) # Label locations
width = 0.35 # Bar width
fig, ax = plt.subplots(figsize=(8, 6))
rects1 = ax.bar(x - width/2, correct_model1, width, label='Model 1 (EfficientNetB0)')
rects2 = ax.bar(x + width/2, correct_model2, width, label='Model 2 (Custom CNN)')
# Add text for labels, title, and custom x-axis tick labels
ax.set_xlabel('Classes', fontsize=12)
ax.set_ylabel('Correct Predictions', fontsize=12)
ax.set_title('Comparison of Correct Predictions by Class', fontsize=16)
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()
# Annotate bars
for rects in [rects1, rects2]:
for rect in rects:
height = rect.get_height()
ax.annotate(f'{height}',
xy=(rect.get_x() + rect.get_width() / 2, height),
xytext=(0, 3), # 3 points vertical offset
textcoords="offset points",
ha='center', va='bottom')
plt.tight_layout()
plt.show()
8.3 ROC-AUC Score Comparison¶
from itertools import cycle
# ROC-AUC Scores for Model 1 and Model 2
roc_auc_model1 = [0.99, 0.99, 1.00, 1.00, 0.99] # Include micro-average at the end
roc_auc_model2 = [0.95, 0.96, 0.99, 1.00, 0.98]
labels = ["Class 0", "Class 1", "Class 2", "Class 3", "Micro-average"]
colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'green', 'deeppink'])
x = np.arange(len(labels)) # Label locations
width = 0.35 # Bar width
fig, ax = plt.subplots(figsize=(10, 6))
rects1 = ax.bar(x - width/2, roc_auc_model1, width, label='Model 1 (EfficientNetB0)', color='skyblue')
rects2 = ax.bar(x + width/2, roc_auc_model2, width, label='Model 2 (Custom CNN)', color='lightcoral')
# Add text for labels, title, and custom x-axis tick labels
ax.set_xlabel('Classes', fontsize=12)
ax.set_ylabel('ROC-AUC Scores', fontsize=12)
ax.set_title('Comparison of ROC-AUC Scores by Class', fontsize=16)
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()
# Annotate bars
for rects in [rects1, rects2]:
for rect in rects:
height = rect.get_height()
ax.annotate(f'{height:.2f}',
xy=(rect.get_x() + rect.get_width() / 2, height),
xytext=(0, 3), # 3 points vertical offset
textcoords="offset points",
ha='center', va='bottom')
plt.tight_layout()
plt.show()
9. Conclusion¶
The evaluation shows that while EfficientNetB0 brings the usual advantages of transfer learning, faster convergence, computational efficiency, and strong generalization, the custom-built CNN was the stronger model for this task. It achieved higher accuracy, precision, recall, and F1-scores, and held up well on the minority classes.
Although the custom CNN is more computationally intensive and slower to train, its capacity to learn task-specific patterns from scratch makes it the preferred model where high accuracy, precision, and recall are critical. One caveat: the custom CNN's metrics in Section 7 were computed on the validation set (4,000 scans), whereas EfficientNetB0 was evaluated on the held-out test set (2,018 scans), so the comparison should be read with that difference in mind.
Overall, by prioritizing performance and reliability over training speed, the custom CNN proved the better choice for this multi-class classification problem, and it leaves room for further optimization and adaptation to related tasks.